{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "application/vnd.databricks.v1+cell": {
     "inputWidgets": {},
     "nuid": "843d3142-24ca-4bd1-9e31-b55163804fe3",
     "showTitle": false,
     "title": ""
    }
   },
   "outputs": [],
   "source": [
    "dbutils.widgets.text(\"RESOURCE_PREFIX\", \"\")\n",
    "dbutils.widgets.text(\"REDIS_KEY\", \"\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "application/vnd.databricks.v1+cell": {
     "inputWidgets": {},
     "nuid": "384e5e16-7213-4186-9d04-09d03b155534",
     "showTitle": false,
     "title": ""
    }
   },
   "source": [
    "# Feathr Feature Store on Databricks Demo Notebook\n",
    "\n",
    "This notebook illustrates the use of Feature Store to create a model that predicts NYC Taxi fares. The dataset comes from [here](https://www1.nyc.gov/site/tlc/about/tlc-trip-record-data.page).\n",
    "\n",
    "This notebook is specifically written for Databricks and is relying on some of the Databricks packages such as `dbutils`. The intention here is to provide a \"one click run\" example with minimum configuration. For example:\n",
    "- This notebook skips feature registry which requires running Azure Purview. \n",
    "- To make the online feature query work, you will need to configure the Redis endpoint. \n",
    "\n",
    "The full-fledged notebook can be found from [here](https://github.com/feathr-ai/feathr/blob/main/docs/samples/nyc_taxi_demo.ipynb)."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "application/vnd.databricks.v1+cell": {
     "inputWidgets": {},
     "nuid": "c2ce58c7-9263-469a-bbb7-43364ddb07b8",
     "showTitle": false,
     "title": ""
    }
   },
   "source": [
    "## Prerequisite\n",
    "\n",
    "To use feathr materialization for online scoring with Redis cache, you may deploy a Redis cluster and set `RESOURCE_PREFIX` and `REDIS_KEY` via Databricks widgets. Note that the deployed Redis host address should be `{RESOURCE_PREFIX}redis.redis.cache.windows.net`. More details about how to deploy the Redis cluster can be found [here](https://feathr-ai.github.io/feathr/how-to-guides/azure-deployment-cli.html#configurure-redis-cluster).\n",
    "\n",
    "To run this notebook, you'll need to install `feathr` pip package. Here, we install notebook-scoped library. For details, please see [Azure Databricks dependency management document](https://learn.microsoft.com/en-us/azure/databricks/libraries/)."
   ]
  },
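  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The Redis host naming convention above can be sketched as a small helper. `redis_host_from_prefix` is hypothetical and written only for this illustration (it is not part of the Feathr API); it simply applies the `{RESOURCE_PREFIX}redis.redis.cache.windows.net` pattern, and the prefix in the example is made up."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def redis_host_from_prefix(resource_prefix: str) -> str:\n",
    "    # Hypothetical helper: build the Redis host name from the naming convention above.\n",
    "    # An empty prefix yields an empty string, in line with this notebook skipping\n",
    "    # online materialization when RESOURCE_PREFIX is not set.\n",
    "    if not resource_prefix:\n",
    "        return ''\n",
    "    return f'{resource_prefix}redis.redis.cache.windows.net'\n",
    "\n",
    "\n",
    "# Example with a made-up prefix.\n",
    "print(redis_host_from_prefix('myfeathr'))"
   ]
  },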
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "application/vnd.databricks.v1+cell": {
     "inputWidgets": {},
     "nuid": "4609d7ad-ad74-40fc-b97e-f440a0fa0737",
     "showTitle": false,
     "title": ""
    }
   },
   "outputs": [],
   "source": [
    "# Install feathr from the latest codes in the repo. You may use `pip install \"feathr[notebook]\"` as well.\n",
    "%pip install \"git+https://github.com/feathr-ai/feathr.git#subdirectory=feathr_project&egg=feathr[notebook]\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "application/vnd.databricks.v1+cell": {
     "inputWidgets": {},
     "nuid": "c81fa80c-bca6-4ae5-84ad-659a036977bd",
     "showTitle": false,
     "title": ""
    }
   },
   "source": [
    "## Notebook Steps\n",
    "\n",
    "This tutorial demonstrates the key capabilities of Feathr, including:\n",
    "\n",
    "1. Install Feathr and necessary dependencies.\n",
    "1. Create shareable features with Feathr feature definition configs.\n",
    "1. Create training data using point-in-time correct feature join\n",
    "1. Train and evaluate a prediction model.\n",
    "1. Materialize feature values for online scoring.\n",
    "\n",
    "The overall data flow is as follows:\n",
    "\n",
    "<img src=\"https://raw.githubusercontent.com/feathr-ai/feathr/main/docs/images/feature_flow.png\" width=\"800\">"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "application/vnd.databricks.v1+cell": {
     "inputWidgets": {},
     "nuid": "80223a02-631c-40c8-91b3-a037249ffff9",
     "showTitle": false,
     "title": ""
    }
   },
   "outputs": [],
   "source": [
    "from datetime import timedelta\n",
    "import os\n",
    "from pathlib import Path\n",
    "\n",
    "from pyspark.ml import Pipeline\n",
    "from pyspark.ml.evaluation import RegressionEvaluator\n",
    "from pyspark.ml.feature import VectorAssembler\n",
    "from pyspark.ml.regression import GBTRegressor\n",
    "from pyspark.sql import DataFrame\n",
    "import pyspark.sql.functions as F\n",
    "\n",
    "import feathr\n",
    "from feathr import (\n",
    "    FeathrClient,\n",
    "    # Feature data types\n",
    "    BOOLEAN,\n",
    "    FLOAT,\n",
    "    INT32,\n",
    "    ValueType,\n",
    "    # Feature data sources\n",
    "    INPUT_CONTEXT,\n",
    "    HdfsSource,\n",
    "    # Feature aggregations\n",
    "    TypedKey,\n",
    "    WindowAggTransformation,\n",
    "    # Feature types and anchor\n",
    "    DerivedFeature,\n",
    "    Feature,\n",
    "    FeatureAnchor,\n",
    "    # Materialization\n",
    "    BackfillTime,\n",
    "    MaterializationSettings,\n",
    "    RedisSink,\n",
    "    # Offline feature computation\n",
    "    FeatureQuery,\n",
    "    ObservationSettings,\n",
    ")\n",
    "from feathr.datasets import nyc_taxi\n",
    "from feathr.spark_provider.feathr_configurations import SparkExecutionConfiguration\n",
    "from feathr.utils.config import generate_config\n",
    "from feathr.utils.job_utils import get_result_df\n",
    "\n",
    "\n",
    "print(\n",
    "    f\"\"\"Feathr version: {feathr.__version__}\n",
    "Databricks runtime version: {spark.conf.get(\"spark.databricks.clusterUsageTags.sparkVersion\")}\"\"\"\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "application/vnd.databricks.v1+cell": {
     "inputWidgets": {},
     "nuid": "ab35fa01-b392-457e-8fde-7e445a3c39b5",
     "showTitle": false,
     "title": ""
    }
   },
   "source": [
    "## 2. Create Shareable Features with Feathr Feature Definition Configs\n",
    "\n",
    "In this notebook, we define all the necessary resource key values for authentication. We use the values passed by the databricks widgets at the top of this notebook. Instead of manually entering the values to the widgets, we can also use [Azure Key Vault](https://azure.microsoft.com/en-us/services/key-vault/) to retrieve them.\n",
    "Please refer to [how-to guide documents for granting key-vault access](https://feathr-ai.github.io/feathr/how-to-guides/azure-deployment-arm.html#3-grant-key-vault-and-synapse-access-to-selected-users-optional) and [Databricks' Azure Key Vault-backed scopes](https://learn.microsoft.com/en-us/azure/databricks/security/secrets/secret-scopes) for more details."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "application/vnd.databricks.v1+cell": {
     "inputWidgets": {},
     "nuid": "09f93a9f-7b33-4d91-8f31-ee3b20991696",
     "showTitle": false,
     "title": ""
    }
   },
   "outputs": [],
   "source": [
    "RESOURCE_PREFIX = dbutils.widgets.get(\"RESOURCE_PREFIX\")\n",
    "PROJECT_NAME = \"feathr_getting_started\"\n",
    "\n",
    "REDIS_KEY = dbutils.widgets.get(\"REDIS_KEY\")\n",
    "\n",
    "# Use a databricks cluster\n",
    "SPARK_CLUSTER = \"databricks\"\n",
    "\n",
    "# Databricks file system path\n",
    "DATA_STORE_PATH = f\"dbfs:/{PROJECT_NAME}\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "application/vnd.databricks.v1+cell": {
     "inputWidgets": {},
     "nuid": "331753d6-1850-47b5-ad97-84b7c01d79d1",
     "showTitle": false,
     "title": ""
    }
   },
   "outputs": [],
   "source": [
    "# Redis credential\n",
    "os.environ[\"REDIS_PASSWORD\"] = REDIS_KEY"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "application/vnd.databricks.v1+cell": {
     "inputWidgets": {},
     "nuid": "08bc3b7e-bbf5-4e3a-9978-fe1aef8c1aee",
     "showTitle": false,
     "title": ""
    }
   },
   "source": [
    "### Configurations\n",
    "\n",
    "Feathr uses a yaml file to define configurations. Please refer to [feathr_config.yaml]( https://github.com//feathr-ai/feathr/blob/main/feathr_project/feathrcli/data/feathr_user_workspace/feathr_config.yaml) for the meaning of each field.\n",
    "\n",
    "In the following cell, we set required databricks credentials automatically by using a databricks notebook context object as well as new job cluster spec."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "ctx = dbutils.notebook.entry_point.getDbutils().notebook().getContext()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "application/vnd.databricks.v1+cell": {
     "inputWidgets": {},
     "nuid": "8cd64e3a-376c-48e6-ba41-5197f3591d48",
     "showTitle": false,
     "title": ""
    }
   },
   "outputs": [],
   "source": [
    "config_path = generate_config(\n",
    "    resource_prefix=RESOURCE_PREFIX,\n",
    "    project_name=PROJECT_NAME,\n",
    "    spark_config__spark_cluster=SPARK_CLUSTER,\n",
    "    # You may set an existing cluster id here, but Databricks recommend to use new clusters for greater reliability.\n",
    "    databricks_cluster_id=None,  # Set None to create a new job cluster\n",
    "    databricks_workspace_token_value=ctx.apiToken().get(),\n",
    "    spark_config__databricks__workspace_instance_url=f\"https://{ctx.tags().get('browserHostName').get()}\",\n",
    ")\n",
    "\n",
    "with open(config_path, \"r\") as f:\n",
    "    print(f.read())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "application/vnd.databricks.v1+cell": {
     "inputWidgets": {},
     "nuid": "58d22dc1-7590-494d-94ca-3e2488c31c8e",
     "showTitle": false,
     "title": ""
    }
   },
   "source": [
    "All the configurations can be overwritten by environment variables with concatenation of `__` for different layers of the config file. For example, `feathr_runtime_location` for databricks config can be overwritten by setting `spark_config__databricks__feathr_runtime_location` environment variable."
   ]
  },
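  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The override rule above can be illustrated with a small sketch. `env_var_name` and `apply_env_overrides` are hypothetical helpers, not Feathr's implementation; they only show how a nested key path such as `spark_config` / `databricks` / `feathr_runtime_location` maps to a `__`-joined variable name, and how a set variable takes precedence over the YAML value. A plain dict stands in for `os.environ` so that running this cell does not change the environment Feathr actually reads."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "from typing import Any, Dict, List, Mapping, Optional\n",
    "\n",
    "\n",
    "def env_var_name(key_path: List[str]) -> str:\n",
    "    # ['spark_config', 'databricks', 'feathr_runtime_location']\n",
    "    # -> 'spark_config__databricks__feathr_runtime_location'\n",
    "    return '__'.join(key_path)\n",
    "\n",
    "\n",
    "def apply_env_overrides(\n",
    "    config: Dict[str, Any],\n",
    "    env: Mapping[str, str] = os.environ,\n",
    "    prefix: Optional[List[str]] = None,\n",
    ") -> Dict[str, Any]:\n",
    "    # Return a copy of the nested config where each leaf value is replaced by the\n",
    "    # variable named after its key path, if that variable is present in env.\n",
    "    prefix = prefix or []\n",
    "    result: Dict[str, Any] = {}\n",
    "    for key, value in config.items():\n",
    "        path = prefix + [key]\n",
    "        if isinstance(value, dict):\n",
    "            result[key] = apply_env_overrides(value, env, path)\n",
    "        else:\n",
    "            result[key] = env.get(env_var_name(path), value)\n",
    "    return result\n",
    "\n",
    "\n",
    "# Made-up values; a plain dict is used instead of os.environ on purpose.\n",
    "example_config = {'spark_config': {'databricks': {'feathr_runtime_location': 'default.jar'}}}\n",
    "example_env = {'spark_config__databricks__feathr_runtime_location': 'custom.jar'}\n",
    "print(apply_env_overrides(example_config, example_env))"
   ]
  },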
  {
   "cell_type": "markdown",
   "metadata": {
    "application/vnd.databricks.v1+cell": {
     "inputWidgets": {},
     "nuid": "3fef7f2f-df19-4f53-90a5-ff7999ed983d",
     "showTitle": false,
     "title": ""
    }
   },
   "source": [
    "### Initialize Feathr Client"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "application/vnd.databricks.v1+cell": {
     "inputWidgets": {},
     "nuid": "9713a2df-c7b2-4562-88b0-b7acce3cc43a",
     "showTitle": false,
     "title": ""
    }
   },
   "outputs": [],
   "source": [
    "client = FeathrClient(config_path=config_path)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "application/vnd.databricks.v1+cell": {
     "inputWidgets": {},
     "nuid": "c3b64bda-d42c-4a64-b976-0fb604cf38c5",
     "showTitle": false,
     "title": ""
    }
   },
   "source": [
    "### View the NYC taxi fare dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "application/vnd.databricks.v1+cell": {
     "inputWidgets": {},
     "nuid": "c4ccd7b3-298a-4e5a-8eec-b7e309db393e",
     "showTitle": false,
     "title": ""
    }
   },
   "outputs": [],
   "source": [
    "DATA_FILE_PATH = str(Path(DATA_STORE_PATH, \"nyc_taxi.csv\"))\n",
    "\n",
    "# Download the data file\n",
    "df_raw = nyc_taxi.get_spark_df(spark=spark, local_cache_path=DATA_FILE_PATH)\n",
    "df_raw.limit(5).show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "application/vnd.databricks.v1+cell": {
     "inputWidgets": {},
     "nuid": "7430c942-64e5-4b70-b823-16ce1d1b3cee",
     "showTitle": false,
     "title": ""
    }
   },
   "source": [
    "### Defining features with Feathr\n",
    "\n",
    "In Feathr, a feature is viewed as a function, mapping a key and timestamp to a feature value. For more details, please see [Feathr Feature Definition Guide](https://github.com/feathr-ai/feathr/blob/main/docs/concepts/feature-definition.md).\n",
    "\n",
    "* The feature key (a.k.a. entity id) identifies the subject of feature, e.g. a user_id or location_id.\n",
    "* The feature name is the aspect of the entity that the feature is indicating, e.g. the age of the user.\n",
    "* The feature value is the actual value of that aspect at a particular time, e.g. the value is 30 at year 2022.\n",
    "\n",
    "Note that, in some cases, a feature could be just a transformation function that has no entity key or timestamp involved, e.g. *the day of week of the request timestamp*.\n",
    "\n",
    "There are two types of features -- anchored features and derivated features:\n",
    "\n",
    "* **Anchored features**: Features that are directly extracted from sources. Could be with or without aggregation. \n",
    "* **Derived features**: Features that are computed on top of other features.\n",
    "\n",
    "#### Define anchored features\n",
    "\n",
    "A feature source is needed for anchored features that describes the raw data in which the feature values are computed from. A source value should be either `INPUT_CONTEXT` (the features that will be extracted from the observation data directly) or `feathr.source.Source` object."
   ]
  },
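  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The view of a feature as a function of a key and a timestamp can be illustrated without any Feathr API. In the plain-Python sketch below, `user_age` (backed by made-up data) plays the role of a keyed feature that returns the latest known value at a given time, while `day_of_week` is a key-less feature computed from the timestamp alone. Both functions are hypothetical and only meant to show the concept."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from datetime import datetime\n",
    "from typing import Optional\n",
    "\n",
    "# Made-up history of a 'user_age' feature, keyed by user_id and stored as\n",
    "# (effective_from, value) pairs in chronological order.\n",
    "_user_age_history = {1: [(datetime(2021, 1, 1), 29), (datetime(2022, 1, 1), 30)]}\n",
    "\n",
    "\n",
    "def user_age(user_id: int, ts: datetime) -> Optional[int]:\n",
    "    # Keyed feature: the value of the entity's aspect as of time ts.\n",
    "    value = None\n",
    "    for effective_from, v in _user_age_history.get(user_id, []):\n",
    "        if effective_from <= ts:\n",
    "            value = v\n",
    "    return value\n",
    "\n",
    "\n",
    "def day_of_week(ts: datetime) -> int:\n",
    "    # Key-less feature: depends only on the request timestamp\n",
    "    # (ISO numbering here: Monday = 1, ..., Sunday = 7).\n",
    "    return ts.isoweekday()\n",
    "\n",
    "\n",
    "print(user_age(1, datetime(2022, 6, 1)))\n",
    "print(day_of_week(datetime(2022, 6, 1)))"
   ]
  },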
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "application/vnd.databricks.v1+cell": {
     "inputWidgets": {},
     "nuid": "75b8d2ed-84df-4446-ae07-5f715434f3ea",
     "showTitle": false,
     "title": ""
    }
   },
   "outputs": [],
   "source": [
    "TIMESTAMP_COL = \"lpep_dropoff_datetime\"\n",
    "TIMESTAMP_FORMAT = \"yyyy-MM-dd HH:mm:ss\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "application/vnd.databricks.v1+cell": {
     "inputWidgets": {},
     "nuid": "93abbcc2-562b-47e4-ad4c-1fedd7cc64df",
     "showTitle": false,
     "title": ""
    }
   },
   "outputs": [],
   "source": [
    "# We define f_trip_distance and f_trip_time_duration features separately\n",
    "# so that we can reuse them later for the derived features.\n",
    "f_trip_distance = Feature(\n",
    "    name=\"f_trip_distance\",\n",
    "    feature_type=FLOAT,\n",
    "    transform=\"trip_distance\",\n",
    ")\n",
    "f_trip_time_duration = Feature(\n",
    "    name=\"f_trip_time_duration\",\n",
    "    feature_type=FLOAT,\n",
    "    transform=\"cast_float((to_unix_timestamp(lpep_dropoff_datetime) - to_unix_timestamp(lpep_pickup_datetime)) / 60)\",\n",
    ")\n",
    "\n",
    "features = [\n",
    "    f_trip_distance,\n",
    "    f_trip_time_duration,\n",
    "    Feature(\n",
    "        name=\"f_is_long_trip_distance\",\n",
    "        feature_type=BOOLEAN,\n",
    "        transform=\"trip_distance > 30.0\",\n",
    "    ),\n",
    "    Feature(\n",
    "        name=\"f_day_of_week\",\n",
    "        feature_type=INT32,\n",
    "        transform=\"dayofweek(lpep_dropoff_datetime)\",\n",
    "    ),\n",
    "    Feature(\n",
    "        name=\"f_day_of_month\",\n",
    "        feature_type=INT32,\n",
    "        transform=\"dayofmonth(lpep_dropoff_datetime)\",\n",
    "    ),\n",
    "    Feature(\n",
    "        name=\"f_hour_of_day\",\n",
    "        feature_type=INT32,\n",
    "        transform=\"hour(lpep_dropoff_datetime)\",\n",
    "    ),\n",
    "]\n",
    "\n",
    "# After you have defined features, bring them together to build the anchor to the source.\n",
    "feature_anchor = FeatureAnchor(\n",
    "    name=\"feature_anchor\",\n",
    "    source=INPUT_CONTEXT,  # Pass through source, i.e. observation data.\n",
    "    features=features,\n",
    ")"
   ]
  },
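  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The `transform` strings above are Spark SQL expressions evaluated on each row of the observation data. As a rough cross-check, the sketch below recomputes the same anchored feature values for one made-up trip in plain Python; `compute_trip_features` is a hypothetical helper, not part of Feathr. The mapping to Spark semantics is an assumption worth verifying: Spark's `dayofweek` numbers days from 1 (Sunday) to 7 (Saturday), and the `to_unix_timestamp` difference is in seconds, hence the division by 60."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from datetime import datetime\n",
    "\n",
    "TS_FORMAT = '%Y-%m-%d %H:%M:%S'  # Python counterpart of 'yyyy-MM-dd HH:mm:ss'\n",
    "\n",
    "\n",
    "def compute_trip_features(row: dict) -> dict:\n",
    "    # Row-level Python counterparts of the Spark SQL transforms above.\n",
    "    pickup = datetime.strptime(row['lpep_pickup_datetime'], TS_FORMAT)\n",
    "    dropoff = datetime.strptime(row['lpep_dropoff_datetime'], TS_FORMAT)\n",
    "    return {\n",
    "        'f_trip_distance': float(row['trip_distance']),\n",
    "        # Seconds between the two timestamps, converted to minutes.\n",
    "        'f_trip_time_duration': (dropoff - pickup).total_seconds() / 60,\n",
    "        'f_is_long_trip_distance': row['trip_distance'] > 30.0,\n",
    "        # Convert ISO weekday (Mon = 1 .. Sun = 7) to Spark's dayofweek (Sun = 1 .. Sat = 7).\n",
    "        'f_day_of_week': dropoff.isoweekday() % 7 + 1,\n",
    "        'f_day_of_month': dropoff.day,\n",
    "        'f_hour_of_day': dropoff.hour,\n",
    "    }\n",
    "\n",
    "\n",
    "# Made-up sample trip.\n",
    "sample_trip = {\n",
    "    'lpep_pickup_datetime': '2020-04-01 08:10:00',\n",
    "    'lpep_dropoff_datetime': '2020-04-01 08:40:00',\n",
    "    'trip_distance': 5.2,\n",
    "}\n",
    "print(compute_trip_features(sample_trip))"
   ]
  },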
  {
   "cell_type": "markdown",
   "metadata": {
    "application/vnd.databricks.v1+cell": {
     "inputWidgets": {},
     "nuid": "728d2d5f-c11f-4941-bdc5-48507f5749f1",
     "showTitle": false,
     "title": ""
    }
   },
   "source": [
    "We can define the source with a preprocessing python function."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "application/vnd.databricks.v1+cell": {
     "inputWidgets": {},
     "nuid": "3cc59a0e-a41b-480e-a84e-ca5443d63143",
     "showTitle": false,
     "title": ""
    }
   },
   "outputs": [],
   "source": [
    "def preprocessing(df: DataFrame) -> DataFrame:\n",
    "    import pyspark.sql.functions as F\n",
    "\n",
    "    df = df.withColumn(\n",
    "        \"fare_amount_cents\", (F.col(\"fare_amount\") * 100.0).cast(\"float\")\n",
    "    )\n",
    "    return df\n",
    "\n",
    "\n",
    "batch_source = HdfsSource(\n",
    "    name=\"nycTaxiBatchSource\",\n",
    "    path=DATA_FILE_PATH,\n",
    "    event_timestamp_column=TIMESTAMP_COL,\n",
    "    preprocessing=preprocessing,\n",
    "    timestamp_format=TIMESTAMP_FORMAT,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "application/vnd.databricks.v1+cell": {
     "inputWidgets": {},
     "nuid": "46f863c4-bb81-434a-a448-6b585031a221",
     "showTitle": false,
     "title": ""
    }
   },
   "source": [
    "For the features with aggregation, the supported functions are as follows:\n",
    "\n",
    "| Aggregation Function | Input Type | Description |\n",
    "| --- | --- | --- |\n",
    "|SUM, COUNT, MAX, MIN, AVG\t|Numeric|Applies the the numerical operation on the numeric inputs. |\n",
    "|MAX_POOLING, MIN_POOLING, AVG_POOLING\t| Numeric Vector | Applies the max/min/avg operation on a per entry bassis for a given a collection of numbers.|\n",
    "|LATEST| Any |Returns the latest not-null values from within the defined time window |"
   ]
  },
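  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the table concrete, the sketch below applies each aggregation family to small made-up inputs in plain Python. It only illustrates the semantics described in the table and is not how Feathr computes them (Feathr runs these aggregations in Spark). `numeric_agg`, `pooling`, and `latest` are hypothetical helpers, and skipping nulls in the numeric aggregations is an assumption of this sketch."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from typing import List, Optional, Sequence, Tuple\n",
    "\n",
    "\n",
    "def numeric_agg(values: Sequence[Optional[float]], func: str) -> Optional[float]:\n",
    "    # SUM / COUNT / MAX / MIN / AVG over the non-null inputs (nulls skipped by assumption).\n",
    "    xs = [v for v in values if v is not None]\n",
    "    if func == 'COUNT':\n",
    "        return len(xs)\n",
    "    if not xs:\n",
    "        return None\n",
    "    if func == 'SUM':\n",
    "        return sum(xs)\n",
    "    if func == 'MAX':\n",
    "        return max(xs)\n",
    "    if func == 'MIN':\n",
    "        return min(xs)\n",
    "    if func == 'AVG':\n",
    "        return sum(xs) / len(xs)\n",
    "    raise ValueError(f'unsupported function: {func}')\n",
    "\n",
    "\n",
    "def pooling(vectors: Sequence[Sequence[float]], func: str) -> List[float]:\n",
    "    # *_POOLING: apply max/min/avg position by position across the vectors.\n",
    "    columns = list(zip(*vectors))\n",
    "    if func == 'MAX_POOLING':\n",
    "        return [max(c) for c in columns]\n",
    "    if func == 'MIN_POOLING':\n",
    "        return [min(c) for c in columns]\n",
    "    if func == 'AVG_POOLING':\n",
    "        return [sum(c) / len(c) for c in columns]\n",
    "    raise ValueError(f'unsupported function: {func}')\n",
    "\n",
    "\n",
    "def latest(values_by_time: Sequence[Tuple[int, object]]) -> object:\n",
    "    # LATEST: the most recent non-null value among (timestamp, value) pairs.\n",
    "    non_null = [(t, v) for t, v in values_by_time if v is not None]\n",
    "    if not non_null:\n",
    "        return None\n",
    "    return max(non_null, key=lambda tv: tv[0])[1]\n",
    "\n",
    "\n",
    "sample_fares = [1200.0, None, 800.0, 2500.0]\n",
    "print(numeric_agg(sample_fares, 'AVG'))\n",
    "print(pooling([[1.0, 5.0], [3.0, 2.0]], 'MAX_POOLING'))\n",
    "print(latest([(1, 10), (2, 7), (3, None)]))"
   ]
  },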
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "application/vnd.databricks.v1+cell": {
     "inputWidgets": {},
     "nuid": "a373ecbe-a040-4cd3-9d87-0d5f4c5ba553",
     "showTitle": false,
     "title": ""
    }
   },
   "outputs": [],
   "source": [
    "agg_key = TypedKey(\n",
    "    key_column=\"DOLocationID\",\n",
    "    key_column_type=ValueType.INT32,\n",
    "    description=\"location id in NYC\",\n",
    "    full_name=\"nyc_taxi.location_id\",\n",
    ")\n",
    "\n",
    "agg_window = \"90d\"\n",
    "\n",
    "# Anchored features with aggregations\n",
    "agg_features = [\n",
    "    Feature(\n",
    "        name=\"f_location_avg_fare\",\n",
    "        key=agg_key,\n",
    "        feature_type=FLOAT,\n",
    "        transform=WindowAggTransformation(\n",
    "            agg_expr=\"fare_amount_cents\",\n",
    "            agg_func=\"AVG\",\n",
    "            window=agg_window,\n",
    "        ),\n",
    "    ),\n",
    "    Feature(\n",
    "        name=\"f_location_max_fare\",\n",
    "        key=agg_key,\n",
    "        feature_type=FLOAT,\n",
    "        transform=WindowAggTransformation(\n",
    "            agg_expr=\"fare_amount_cents\",\n",
    "            agg_func=\"MAX\",\n",
    "            window=agg_window,\n",
    "        ),\n",
    "    ),\n",
    "]\n",
    "\n",
    "agg_feature_anchor = FeatureAnchor(\n",
    "    name=\"agg_feature_anchor\",\n",
    "    source=batch_source,  # External data source for feature. Typically a data table.\n",
    "    features=agg_features,\n",
    ")"
   ]
  },
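  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Conceptually, `f_location_avg_fare` for a given `DOLocationID` at observation time `t` aggregates that location's `fare_amount_cents` values whose event timestamps fall within the 90 days before `t`. The plain-Python sketch below illustrates this with made-up events. `window_agg` is a hypothetical helper, and the exact window boundaries used here (after `t - 90d`, up to and including `t`) are an assumption rather than something stated in this notebook."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from datetime import datetime, timedelta\n",
    "from statistics import mean\n",
    "from typing import Dict, List, Optional, Tuple\n",
    "\n",
    "# Made-up events: (DOLocationID, event timestamp, fare_amount_cents).\n",
    "sample_events: List[Tuple[int, datetime, float]] = [\n",
    "    (41, datetime(2020, 1, 5), 1200.0),\n",
    "    (41, datetime(2020, 3, 1), 800.0),\n",
    "    (41, datetime(2020, 4, 20), 2500.0),\n",
    "    (74, datetime(2020, 4, 21), 900.0),\n",
    "]\n",
    "\n",
    "\n",
    "def window_agg(location_id: int, obs_time: datetime, window: timedelta) -> Dict[str, Optional[float]]:\n",
    "    # Keep only this key's events inside (obs_time - window, obs_time].\n",
    "    fares = [\n",
    "        fare\n",
    "        for loc, ts, fare in sample_events\n",
    "        if loc == location_id and obs_time - window < ts <= obs_time\n",
    "    ]\n",
    "    return {\n",
    "        'f_location_avg_fare': mean(fares) if fares else None,\n",
    "        'f_location_max_fare': max(fares) if fares else None,\n",
    "    }\n",
    "\n",
    "\n",
    "print(window_agg(41, datetime(2020, 4, 30), timedelta(days=90)))"
   ]
  },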
  {
   "cell_type": "markdown",
   "metadata": {
    "application/vnd.databricks.v1+cell": {
     "inputWidgets": {},
     "nuid": "149f85e2-fa3c-4895-b0c5-de5543ca9b6d",
     "showTitle": false,
     "title": ""
    }
   },
   "source": [
    "#### Define derived features\n",
    "\n",
    "We also define a derived feature, `f_trip_time_distance`, from the anchored features `f_trip_distance` and `f_trip_time_duration` as follows:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "application/vnd.databricks.v1+cell": {
     "inputWidgets": {},
     "nuid": "05633bc3-9118-449b-9562-45fc437576c2",
     "showTitle": false,
     "title": ""
    }
   },
   "outputs": [],
   "source": [
    "derived_features = [\n",
    "    DerivedFeature(\n",
    "        name=\"f_trip_time_distance\",\n",
    "        feature_type=FLOAT,\n",
    "        input_features=[\n",
    "            f_trip_distance,\n",
    "            f_trip_time_duration,\n",
    "        ],\n",
    "        transform=\"f_trip_distance / f_trip_time_duration\",\n",
    "    )\n",
    "]"
   ]
  },
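  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A derived feature is computed from the values of its input features on the same row. The sketch below mirrors the `f_trip_distance / f_trip_time_duration` transform in plain Python (distance per minute of trip time). Returning `None` for a non-positive duration is an illustrative choice for this hypothetical helper, not Feathr's behavior; how Spark evaluates a division by zero depends on its ANSI SQL setting. The training step later in this notebook keeps only rows with `f_trip_time_duration > 0`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from typing import Optional\n",
    "\n",
    "\n",
    "def trip_time_distance(f_trip_distance: float, f_trip_time_duration: float) -> Optional[float]:\n",
    "    # Same formula as the derived feature: distance divided by duration in minutes.\n",
    "    # Non-positive durations (e.g. bad timestamps) yield None instead of raising.\n",
    "    if f_trip_time_duration <= 0:\n",
    "        return None\n",
    "    return f_trip_distance / f_trip_time_duration\n",
    "\n",
    "\n",
    "print(trip_time_distance(5.2, 30.0))\n",
    "print(trip_time_distance(5.2, 0.0))"
   ]
  },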
  {
   "cell_type": "markdown",
   "metadata": {
    "application/vnd.databricks.v1+cell": {
     "inputWidgets": {},
     "nuid": "ad102c45-586d-468c-85f0-9454401ef10b",
     "showTitle": false,
     "title": ""
    }
   },
   "source": [
    "### Build features\n",
    "\n",
    "Finally, we build the features."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "application/vnd.databricks.v1+cell": {
     "inputWidgets": {},
     "nuid": "91bb5ebb-87e4-470b-b8eb-1c89b351740e",
     "showTitle": false,
     "title": ""
    }
   },
   "outputs": [],
   "source": [
    "client.build_features(\n",
    "    anchor_list=[feature_anchor, agg_feature_anchor],\n",
    "    derived_feature_list=derived_features,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "application/vnd.databricks.v1+cell": {
     "inputWidgets": {},
     "nuid": "632d5f46-f9e2-41a8-aab7-34f75206e2aa",
     "showTitle": false,
     "title": ""
    }
   },
   "source": [
    "## 3. Create Training Data Using Point-in-Time Correct Feature Join\n",
    "\n",
    "After the feature producers have defined the features (as described in the Feature Definition part), the feature consumers may want to consume those features. Feature consumers will use observation data to query from different feature tables using Feature Query.\n",
    "\n",
    "To create a training dataset using Feathr, one needs to provide a feature join configuration file to specify\n",
    "what features and how these features should be joined to the observation data. \n",
    "\n",
    "To learn more on this topic, please refer to [Point-in-time Correctness](https://github.com//feathr-ai/feathr/blob/main/docs/concepts/point-in-time-join.md)"
   ]
  },
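  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The core idea of a point-in-time correct join is that each observation row only sees feature values whose timestamps are not later than the observation's own timestamp, so no future information leaks into the training data. Below is a minimal plain-Python sketch of that lookup using `bisect`; the data is made up, and `point_in_time_lookup` is a hypothetical illustration, not Feathr's join implementation."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from bisect import bisect_right\n",
    "from datetime import datetime\n",
    "from typing import Dict, List, Optional, Tuple\n",
    "\n",
    "# Made-up feature timeline per key: (feature timestamp, value) pairs sorted by time.\n",
    "feature_timeline: Dict[int, List[Tuple[datetime, float]]] = {\n",
    "    41: [\n",
    "        (datetime(2020, 1, 1), 1000.0),\n",
    "        (datetime(2020, 2, 1), 1100.0),\n",
    "        (datetime(2020, 3, 1), 1300.0),\n",
    "    ],\n",
    "}\n",
    "\n",
    "\n",
    "def point_in_time_lookup(key: int, obs_time: datetime) -> Optional[float]:\n",
    "    # Return the latest feature value whose timestamp is <= obs_time, else None.\n",
    "    timeline = feature_timeline.get(key, [])\n",
    "    times = [ts for ts, _ in timeline]\n",
    "    idx = bisect_right(times, obs_time)\n",
    "    return timeline[idx - 1][1] if idx > 0 else None\n",
    "\n",
    "\n",
    "# Made-up observation rows: (key, observation timestamp).\n",
    "observations = [\n",
    "    (41, datetime(2020, 2, 15)),\n",
    "    (41, datetime(2019, 12, 31)),\n",
    "    (99, datetime(2020, 3, 5)),\n",
    "]\n",
    "joined = [(key, ts, point_in_time_lookup(key, ts)) for key, ts in observations]\n",
    "print(joined)"
   ]
  },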
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "application/vnd.databricks.v1+cell": {
     "inputWidgets": {},
     "nuid": "02feabc9-2f2f-43e8-898d-b28082798e98",
     "showTitle": false,
     "title": ""
    }
   },
   "outputs": [],
   "source": [
    "feature_names = [feature.name for feature in features + agg_features + derived_features]\n",
    "feature_names"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "application/vnd.databricks.v1+cell": {
     "inputWidgets": {},
     "nuid": "e438e6d8-162e-4aa3-b3b3-9d1f3b0d2b7f",
     "showTitle": false,
     "title": ""
    }
   },
   "outputs": [],
   "source": [
    "DATA_FORMAT = \"parquet\"\n",
    "offline_features_path = str(\n",
    "    Path(DATA_STORE_PATH, \"feathr_output\", f\"features.{DATA_FORMAT}\")\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "application/vnd.databricks.v1+cell": {
     "inputWidgets": {},
     "nuid": "67e81466-c736-47ba-b122-e640642c01cf",
     "showTitle": false,
     "title": ""
    }
   },
   "outputs": [],
   "source": [
    "# Features that we want to request. Can use a subset of features\n",
    "query = FeatureQuery(\n",
    "    feature_list=feature_names,\n",
    "    key=agg_key,\n",
    ")\n",
    "settings = ObservationSettings(\n",
    "    observation_path=DATA_FILE_PATH,\n",
    "    event_timestamp_column=TIMESTAMP_COL,\n",
    "    timestamp_format=TIMESTAMP_FORMAT,\n",
    ")\n",
    "client.get_offline_features(\n",
    "    observation_settings=settings,\n",
    "    feature_query=query,\n",
    "    # Note, execution_configurations argument only works when using a new job cluster\n",
    "    # For more details, see https://feathr-ai.github.io/feathr/how-to-guides/feathr-job-configuration.html\n",
    "    execution_configurations=SparkExecutionConfiguration(\n",
    "        {\n",
    "            \"spark.feathr.outputFormat\": DATA_FORMAT,\n",
    "        }\n",
    "    ),\n",
    "    output_path=offline_features_path,\n",
    ")\n",
    "\n",
    "client.wait_job_to_finish(timeout_sec=5000)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "application/vnd.databricks.v1+cell": {
     "inputWidgets": {},
     "nuid": "9871af55-25eb-41ee-a58a-fda74b1a174e",
     "showTitle": false,
     "title": ""
    }
   },
   "outputs": [],
   "source": [
    "# Show feature results\n",
    "df = get_result_df(\n",
    "    spark=spark,\n",
    "    client=client,\n",
    "    data_format=\"parquet\",\n",
    "    res_url=offline_features_path,\n",
    ")\n",
    "df.select(feature_names).limit(5).toPandas()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "application/vnd.databricks.v1+cell": {
     "inputWidgets": {},
     "nuid": "dcbf17fc-7f79-4a65-a3af-9cffbd0b5d1f",
     "showTitle": false,
     "title": ""
    }
   },
   "source": [
    "## 4. Train and Evaluate a Prediction Model\n",
    "\n",
    "After generating all the features, we train and evaluate a machine learning model to predict the NYC taxi fare prediction. In this example, we use Spark MLlib's [GBTRegressor](https://spark.apache.org/docs/latest/ml-classification-regression.html#gradient-boosted-tree-regression).\n",
    "\n",
    "Note that designing features, training prediction models and evaluating them are an iterative process where the models' performance maybe used to modify the features as a part of the modeling process."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "application/vnd.databricks.v1+cell": {
     "inputWidgets": {},
     "nuid": "5a226026-1c7b-48db-8f91-88d5c2ddf023",
     "showTitle": false,
     "title": ""
    }
   },
   "source": [
    "### Load Train and Test Data from the Offline Feature Values"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "application/vnd.databricks.v1+cell": {
     "inputWidgets": {},
     "nuid": "bd2cdc83-0920-46e8-9454-e5e6e7832ce0",
     "showTitle": false,
     "title": ""
    }
   },
   "outputs": [],
   "source": [
    "# Train / test split\n",
    "train_df, test_df = (\n",
    "    df.withColumn(  # Dataframe that we generated from get_offline_features call.\n",
    "        \"label\", F.col(\"fare_amount\").cast(\"double\")\n",
    "    )\n",
    "    .where(F.col(\"f_trip_time_duration\") > 0)\n",
    "    .fillna(0)\n",
    "    .randomSplit([0.8, 0.2])\n",
    ")\n",
    "\n",
    "print(f\"Num train samples: {train_df.count()}\")\n",
    "print(f\"Num test samples: {test_df.count()}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "application/vnd.databricks.v1+cell": {
     "inputWidgets": {},
     "nuid": "6a3e2ab1-5c66-4d27-a737-c5e2af03b1dd",
     "showTitle": false,
     "title": ""
    }
   },
   "source": [
    "### Build a ML Pipeline\n",
    "\n",
    "Here, we use Spark ML Pipeline to aggregate feature vectors and feed them to the model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "application/vnd.databricks.v1+cell": {
     "inputWidgets": {},
     "nuid": "2a254361-63e9-45b2-8c19-40549762eacb",
     "showTitle": false,
     "title": ""
    }
   },
   "outputs": [],
   "source": [
    "# Generate a feature vector column for SparkML\n",
    "vector_assembler = VectorAssembler(\n",
    "    inputCols=[x for x in df.columns if x in feature_names],\n",
    "    outputCol=\"features\",\n",
    ")\n",
    "\n",
    "# Define a model\n",
    "gbt = GBTRegressor(\n",
    "    featuresCol=\"features\",\n",
    "    maxIter=100,\n",
    "    maxDepth=5,\n",
    "    maxBins=16,\n",
    ")\n",
    "\n",
    "# Create a ML pipeline\n",
    "ml_pipeline = Pipeline(\n",
    "    stages=[\n",
    "        vector_assembler,\n",
    "        gbt,\n",
    "    ]\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "application/vnd.databricks.v1+cell": {
     "inputWidgets": {},
     "nuid": "bef93538-9591-4247-97b6-289d2055b7b1",
     "showTitle": false,
     "title": ""
    }
   },
   "source": [
    "### Train and Evaluate the Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "application/vnd.databricks.v1+cell": {
     "inputWidgets": {},
     "nuid": "0c3d5f35-11a3-4644-9992-5860169d8302",
     "showTitle": false,
     "title": ""
    }
   },
   "outputs": [],
   "source": [
    "# Train a model\n",
    "model = ml_pipeline.fit(train_df)\n",
    "\n",
    "# Make predictions\n",
    "predictions = model.transform(test_df)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "application/vnd.databricks.v1+cell": {
     "inputWidgets": {},
     "nuid": "1f9b584c-6228-4a02-a6c3-9b8dd2b78091",
     "showTitle": false,
     "title": ""
    }
   },
   "outputs": [],
   "source": [
    "# Evaluate\n",
    "evaluator = RegressionEvaluator(\n",
    "    labelCol=\"label\",\n",
    "    predictionCol=\"prediction\",\n",
    ")\n",
    "\n",
    "rmse = evaluator.evaluate(predictions, {evaluator.metricName: \"rmse\"})\n",
    "mae = evaluator.evaluate(predictions, {evaluator.metricName: \"mae\"})\n",
    "print(f\"RMSE: {rmse}\\nMAE: {mae}\")"
   ]
  },
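  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For reference, the two metrics printed above follow the usual definitions: RMSE is the square root of the mean squared difference between label and prediction, and MAE is the mean absolute difference. The plain-Python sketch below computes both on made-up numbers, which can help when sanity-checking the `RegressionEvaluator` output; `rmse_mae` is a hypothetical helper."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "from typing import Sequence, Tuple\n",
    "\n",
    "\n",
    "def rmse_mae(labels: Sequence[float], preds: Sequence[float]) -> Tuple[float, float]:\n",
    "    # RMSE = sqrt(mean((y - y_hat)^2)), MAE = mean(|y - y_hat|)\n",
    "    errors = [y - p for y, p in zip(labels, preds)]\n",
    "    n = len(errors)\n",
    "    rmse_value = math.sqrt(sum(e * e for e in errors) / n)\n",
    "    mae_value = sum(abs(e) for e in errors) / n\n",
    "    return rmse_value, mae_value\n",
    "\n",
    "\n",
    "# Made-up labels and predictions.\n",
    "print(rmse_mae([1.0, 2.0, 3.0, 4.0], [1.0, 2.0, 3.0, 8.0]))"
   ]
  },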
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "application/vnd.databricks.v1+cell": {
     "inputWidgets": {},
     "nuid": "25c33abd-6e87-437d-a6a1-86435f065a1e",
     "showTitle": false,
     "title": ""
    }
   },
   "outputs": [],
   "source": [
    "# predicted fare vs actual fare plots -- will this work for databricks / synapse / local ?\n",
    "predictions_pdf = predictions.select([\"label\", \"prediction\"]).toPandas().reset_index()\n",
    "\n",
    "predictions_pdf.plot(\n",
    "    x=\"index\",\n",
    "    y=[\"label\", \"prediction\"],\n",
    "    style=[\"-\", \":\"],\n",
    "    figsize=(20, 10),\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "application/vnd.databricks.v1+cell": {
     "inputWidgets": {},
     "nuid": "664d78cc-4a92-430c-9e05-565ba904558e",
     "showTitle": false,
     "title": ""
    }
   },
   "outputs": [],
   "source": [
    "predictions_pdf.plot.scatter(\n",
    "    x=\"label\",\n",
    "    y=\"prediction\",\n",
    "    xlim=(0, 100),\n",
    "    ylim=(0, 100),\n",
    "    figsize=(10, 10),\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "application/vnd.databricks.v1+cell": {
     "inputWidgets": {},
     "nuid": "8a56d165-c813-4ce0-8ae6-9f4d313c463d",
     "showTitle": false,
     "title": ""
    }
   },
   "source": [
    "## 5. Materialize Feature Values for Online Scoring\n",
    "\n",
    "While we computed feature values on-the-fly at request time via Feathr, we can pre-compute the feature values and materialize them to offline or online storages such as Redis.\n",
    "\n",
    "Note, only the features anchored to offline data source can be materialized."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "application/vnd.databricks.v1+cell": {
     "inputWidgets": {},
     "nuid": "751fa72e-8f94-40a1-994e-3e8315b51d37",
     "showTitle": false,
     "title": ""
    }
   },
   "outputs": [],
   "source": [
    "materialized_feature_names = [feature.name for feature in agg_features]\n",
    "materialized_feature_names"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "application/vnd.databricks.v1+cell": {
     "inputWidgets": {},
     "nuid": "4d4699ed-42e6-408f-903d-2f799284f4b6",
     "showTitle": false,
     "title": ""
    }
   },
   "outputs": [],
   "source": [
    "if REDIS_KEY and RESOURCE_PREFIX:\n",
    "    FEATURE_TABLE_NAME = \"nycTaxiDemoFeature\"\n",
    "\n",
    "    # Get the last date from the dataset\n",
    "    backfill_timestamp = (\n",
    "        df_raw.select(\n",
    "            F.to_timestamp(F.col(TIMESTAMP_COL), TIMESTAMP_FORMAT).alias(TIMESTAMP_COL)\n",
    "        )\n",
    "        .agg({TIMESTAMP_COL: \"max\"})\n",
    "        .collect()[0][0]\n",
    "    )\n",
    "\n",
    "    # Time range to materialize: only the last day in the dataset (start == end)\n",
    "    backfill_time = BackfillTime(\n",
    "        start=backfill_timestamp,\n",
    "        end=backfill_timestamp,\n",
    "        step=timedelta(days=1),\n",
    "    )\n",
    "\n",
    "    # Destination sinks:\n",
    "    # Online store (Redis)\n",
    "    redis_sink = RedisSink(table_name=FEATURE_TABLE_NAME)\n",
    "\n",
    "    # Offline store (e.g., ADLS): uncomment and set output_path to use it\n",
    "    # adls_sink = HdfsSink(output_path=)\n",
    "\n",
    "    settings = MaterializationSettings(\n",
    "        name=FEATURE_TABLE_NAME + \".job\",  # job name\n",
    "        backfill_time=backfill_time,\n",
    "        sinks=[redis_sink],  # or adls_sink\n",
    "        feature_names=materialized_feature_names,\n",
    "    )\n",
    "\n",
    "    client.materialize_features(\n",
    "        settings=settings,\n",
    "        # Note: the execution_configurations argument only takes effect when using a new job cluster\n",
    "        execution_configurations={\"spark.feathr.outputFormat\": \"parquet\"},\n",
    "    )\n",
    "\n",
    "    client.wait_job_to_finish(timeout_sec=5000)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "application/vnd.databricks.v1+cell": {
     "inputWidgets": {},
     "nuid": "5aa13acd-58ec-4fc2-86bb-dc1d9951ebb9",
     "showTitle": false,
     "title": ""
    }
   },
   "source": [
    "Now, you can retrieve features for online scoring as follows:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "application/vnd.databricks.v1+cell": {
     "inputWidgets": {},
     "nuid": "424bc9eb-a47f-4b46-be69-8218d55e66ad",
     "showTitle": false,
     "title": ""
    }
   },
   "outputs": [],
   "source": [
    "if REDIS_KEY and RESOURCE_PREFIX:\n",
    "    # Note: to get features for a single key, use client.get_online_features instead\n",
    "    materialized_feature_values = client.multi_get_online_features(\n",
    "        feature_table=FEATURE_TABLE_NAME,\n",
    "        keys=[\"239\", \"265\"],\n",
    "        feature_names=materialized_feature_names,\n",
    "    )\n",
    "    materialized_feature_values"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "application/vnd.databricks.v1+cell": {
     "inputWidgets": {},
     "nuid": "3596dc71-a363-4b6a-a169-215c89978558",
     "showTitle": false,
     "title": ""
    }
   },
   "source": [
    "## Cleanup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "application/vnd.databricks.v1+cell": {
     "inputWidgets": {},
     "nuid": "b5fb292e-bbb6-4dd7-8e79-c62d9533e820",
     "showTitle": false,
     "title": ""
    }
   },
   "outputs": [],
   "source": [
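    "# Remove temporary files (note: this recursively deletes everything under dbfs:/tmp/, not only files created by this notebook)\n",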
    "dbutils.fs.rm(\"dbfs:/tmp/\", recurse=True)"
   ]
  }
 ],
 "metadata": {
  "application/vnd.databricks.v1+notebook": {
   "dashboards": [],
   "language": "python",
   "notebookMetadata": {
    "pythonIndentUnit": 4
   },
   "notebookName": "databricks_quickstart_nyc_taxi_demo",
   "notebookOrigID": 2365994027381987,
   "widgets": {
    "REDIS_KEY": {
     "currentValue": "",
     "nuid": "d39ce0d5-bcfe-47ef-b3d9-eff67e5cdeca",
     "widgetInfo": {
      "defaultValue": "",
      "label": null,
      "name": "REDIS_KEY",
      "options": {
       "validationRegex": null,
       "widgetType": "text"
      },
      "widgetType": "text"
     }
    },
    "RESOURCE_PREFIX": {
     "currentValue": "",
     "nuid": "87a26035-86fc-4dbd-8dd0-dc546c1c63c1",
     "widgetInfo": {
      "defaultValue": "",
      "label": null,
      "name": "RESOURCE_PREFIX",
      "options": {
       "validationRegex": null,
       "widgetType": "text"
      },
      "widgetType": "text"
     }
    }
   }
  },
  "kernelspec": {
   "display_name": "Python 3.10.4 ('feathr')",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.9 (main, Jan 11 2023, 15:21:40) [GCC 11.2.0]"
  },
  "vscode": {
   "interpreter": {
    "hash": "e34a1a57d2e174682770a82d94a178aa36d3ccfaa21227c5d2308e319b7ae532"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
